library(pacman)
p_load(
  here, fst, data.table, ggplot2, magrittr, janitor, stringr, purrr,
  collapse, patchwork
)
theme_set(theme_minimal(base_size = 14))

# Loading data ----------------------------------------------------------------
# Marginal production model, read from the PPRegressors file
model_dt = fread(
  here("Data/electricity-generation/output/PPRegressors-oge.csv")
)
# Character plant id built from the ORISPL code (added by reference)
model_dt[, plant_id_eia := as.character(orispl_code)]
# Key (and sort) the table on the plant id
setkey(model_dt, plant_id_eia)
# Long version of the model: one row per plant x region x (linear|squared)
# term with a non-zero coefficient. The intercept is not a melted column, so
# it does not appear here.
model_dt_long = merge(
  # Coefficients: every column whose name matches 'excess_load_'
  x = melt(
    model_dt,
    id.vars = 'plant_id_eia',
    measure.vars = patterns('excess_load_')
  )[
    value != 0,
    .(
      plant_id_eia,
      # Region code is what is left once the prefix and '_sq' are stripped
      nerc_adj = str_remove_all(variable, 'excess_load_|_sq'),
      # Squared terms are the columns ending in '_sq'
      square = str_detect(variable, '_sq$'),
      coef = value
    )
  ],
  # Estimated excess load where the max/min happens (non-missing only)
  y = melt(
    model_dt,
    id.vars = 'plant_id_eia',
    measure.vars = patterns('est_extremum')
  )[
    !is.na(value),
    .(
      plant_id_eia,
      nerc_adj = str_remove(variable, 'est_extremum_'),
      est_extremum = value
    )
  ],
  by = c('plant_id_eia','nerc_adj'),
  # Left join: coefficients without an extremum get est_extremum = NA
  all.x = TRUE
)
# Some name cleaning
# Step 1: every column whose name contains 'sq' is renamed to
# 'excess_load_sq_<rest>', where <rest> is the old name with its first '_sq'
# and its first 'excess_load_' removed
setnames(
  model_dt,
  old = str_subset(names(model_dt), 'sq'),
  new = paste0(
    'excess_load_sq_',
    str_remove(
      str_remove(str_subset(names(model_dt), 'sq'), '_sq'),
      'excess_load_'
    )
  )
)
# Step 2: every column whose name contains 'excess_load' gets a 'coef_' prefix
setnames(
  model_dt,
  old = str_subset(names(model_dt), 'excess_load'),
  new = paste0('coef_', str_subset(names(model_dt), 'excess_load'))
)
# Load data: region code (lower-cased), timestamp, excess load and its square
nerc_load_dt = fread(
  here("Data/electricity-generation/output/RegLoad-oge.csv")
)
nerc_load_dt = nerc_load_dt[, .(
  nerc_adj = tolower(nerc_adj),
  datetime_utc = utc_time,
  excess_load,
  excess_load_sq = excess_load^2
)]
# Emissions rates, cast wide: one row per plant_id_eia x
# net_generation_mwh_split and one column per output x pollutant combination
emissions_rate_dt = dcast(
  data = read.fst(
    path = here('Data/electricity-generation/output/emissions-md-dt.fst'),
    as.data.table = TRUE
  )[, plant_id_eia := as.character(plant_id_eia)],
  formula = plant_id_eia + net_generation_mwh_split ~ output + pollutant,
  value.var = 'emissions_rate_tons'
)
# Marginal damages: Using SCC 148 (in 2014 USD == $185 in 2020)
tot_md_dt = fread(
  here('Data/electricity-generation/output/damage-per-mwh-oge-scc-148.csv')
)
# Plant id as character so it joins with the other plant tables
tot_md_dt[, plant_id_eia := as.character(plant_id_eia)]
# Avg solar panel production by region (From Mark email 10/17/2022)
# panel_output_dt = data.table(
#   nerc_adj = c('cal','mro','npcc','rfc','serc','tre','wecc'),	
#   avg_output_mwh = c(5935.968, 4802.954, 4546.329, 4584.62, 5161.415, 5274.348, 5444.394)/1000
# )
# Time profile for solar production by state
# state_solar_profile_dt = fread(
#   here("Data/electricity-generation/output/solar-profile-dt.csv")
# )
# Crosswalk between states and nerc regions 
# nerc_state_xwalk = data.table(
#   nerc_adj = c('cal','mro','npcc','rfc','serc','tre','wecc'),
#   state_fips = c(6, 29, 25, 39, 13, 48, 49)
# )
# Plant info table: 2019 nameplate capacity joined (inner joins on
# plant_id_eia) to the model, the emissions rates and the marginal damages,
# in that order.
# An earlier, now disabled, variant built the base table by row-binding the
# earliest-year record per plant from plant-info-dt-oge.fst and
# shaped-info-dt-oge.fst instead of filtering plant-info-dt-oge.fst to 2019.
plant_dt =
  Reduce(
    f = \(left_dt, right_dt) merge(left_dt, right_dt, by = 'plant_id_eia'),
    x = list(model_dt, emissions_rate_dt, tot_md_dt),
    init = read.fst(
      path = here("Data/electricity-generation/plant-info-dt-oge.fst"),
      as.data.table = TRUE
    )[
      year == 2019,
      .(plant_id_eia = as.character(plant_id_eia), namepcap)
    ]
  )[, .(
    plant_id_eia,
    namepcap,
    log_scale,
    intercept,
    # The last merge suffixes the duplicated split column; keep the '.x' copy
    net_generation_mwh_split = net_generation_mwh_split.x,
    md_per_mwh_high,
    md_per_mwh_low,
    emissions_rate_low_co2e,
    emissions_rate_high_co2e,
    emissions_rate_low_nox,
    emissions_rate_high_nox,
    emissions_rate_low_so2,
    emissions_rate_high_so2,
    emissions_rate_low_pm25,
    emissions_rate_high_pm25
  )]
setkey(plant_dt, plant_id_eia)


# Function to proportionately split total excess load
#
# Builds a table of simulated excess load by NERC region for every value in
# `quants`. Each simulated row is then repeated once as 'baseline' and once
# per region; in the copy whose `deviation` equals the row's own region the
# excess load is raised by 1. Finally every row is duplicated with
# `square = TRUE` and `excess_load` squared.
#
# Arguments:
#   nerc_load_dt: data.table with columns nerc_adj, datetime_utc and
#     excess_load. nerc_adj is compared against the lower-case codes 'cal',
#     'wecc' and 'tre'; everything else is treated as 'east'.
#   quants: numeric vector of scaling points / quantile probabilities.
#   type: one of
#     'ic'     - total load is scaled between the observed min and max of
#                each interconnection and split across its regions by their
#                share of summed excess load;
#     'all'    - a single min/max range is split across all regions by their
#                share of summed excess load;
#     'quants' - each region's own empirical quantiles of excess_load.
#   manual_eload_limits: optional c(min, max) overriding the load range.
#     Only used when type = 'all'.
#
# Returns a data.table with columns nerc_adj, quantile, tot_eload, square,
# excess_load and deviation (column order differs for type = 'quants').
# Stops with an error when `type` is not one of the three values above.
eload_proportinal = function(
  nerc_load_dt,
  quants = seq(0,1,by = 0.01),
  type = 'ic',
  manual_eload_limits = NULL
){
  # Fail loudly on a bad `type` instead of handing back a character string
  valid_types = c('ic', 'all', 'quants')
  if(!is.character(type) || length(type) != 1L || !(type %in% valid_types)){
    stop(
      "Incorrect type: `type` must be one of 'ic', 'all' or 'quants'",
      call. = FALSE
    )
  }
  # Interconnection of a (lower-case) NERC region code
  ic_of = function(region){
    fcase(
      region %in% c('cal', 'wecc'), 'west',
      region == 'tre', 'texas',
      default = 'east'
    )
  }
  if(type == 'ic'){
    # Total load max and min, per interconnection
    tot_eload_lim =
      nerc_load_dt[,
        .(tot_load = sum(excess_load)),
        by = .(ic = ic_of(nerc_adj), datetime_utc)
      ][,.(
        max_eload = max(tot_load),
        min_eload = min(tot_load)),
        keyby = ic
      ]
    # Percent of total on average for each region (within its interconnection)
    nerc_adj_wts =
      nerc_load_dt[,.(
        tot_region_load = sum(excess_load)),
        keyby = .(ic = ic_of(nerc_adj), nerc_adj)
      ][,
        pct_load := tot_region_load/sum(tot_region_load),
        by = ic
      ]
    # Scaling from min to max by the quants: grouping by the limits table
    # itself repeats every limits row once per element of `quants`
    tot_eload_dt =
      tot_eload_lim[,
        as.list(data.table(quantile = quants)),
        by = tot_eload_lim
      ][,tot_eload := min_eload + quantile*(max_eload-min_eload)]
    # Making load table: every region of an interconnection gets its share
    load_dt_raw =
      merge(
        tot_eload_dt,
        nerc_adj_wts,
        by = 'ic',
        allow.cartesian = TRUE
      )[,.(
        nerc_adj,
        quantile,
        tot_eload,
        square = FALSE,
        excess_load = tot_eload*pct_load
      )]
  }else if(type == 'all'){
    # Total load max and min.
    # NOTE(review): the totals are still summed per interconnection and
    # timestamp, so these are the extremes of interconnection-level totals
    # rather than of the load summed over all regions -- confirm this is
    # intended for type = 'all'.
    tot_eload_lim =
      nerc_load_dt[,
        .(tot_load = sum(excess_load)),
        by = .(ic = ic_of(nerc_adj), datetime_utc)
      ][,.(
        max_eload = max(tot_load),
        min_eload = min(tot_load)
      )]
    # Manually changing limits if given
    if(!is.null(manual_eload_limits)){
      tot_eload_lim[,':='(
        min_eload = manual_eload_limits[1],
        max_eload = manual_eload_limits[2]
      )]
    }
    # Percent of total on average for each region (across all regions)
    nerc_adj_wts =
      nerc_load_dt[,.(
        tot_region_load = sum(excess_load)),
        keyby = nerc_adj
      ][,pct_load := tot_region_load/sum(tot_region_load)]
    # Scaling from min to max by the quants
    tot_eload_dt =
      tot_eload_lim[,
        as.list(data.table(quantile = quants)),
        by = tot_eload_lim
      ][,tot_eload := min_eload + quantile*(max_eload-min_eload)]
    # Making load table: repeat the region weights for every scaling point
    load_dt_raw =
      tot_eload_dt[,
        as.list(nerc_adj_wts),
        by = tot_eload_dt
      ][,.(
        nerc_adj,
        quantile,
        tot_eload,
        square = FALSE,
        excess_load = tot_eload*pct_load
      )]
  }else{
    # type == 'quants': empirical quantiles of each region's own excess load.
    # rbindlist() returns a genuine data.table, so the `:=` below can add
    # the column by reference.
    load_dt_raw =
      rbindlist(lapply(
        unique(nerc_load_dt$nerc_adj),
        \(nerc){
          data.table(
            nerc_adj = nerc,
            quantile = quants,
            square = FALSE,
            excess_load = quantile(
              nerc_load_dt[nerc_adj == nerc]$excess_load,
              probs = quants
            )
          )
        }
      ))
    # Total across regions at each quantile
    load_dt_raw[,tot_eload := sum(excess_load), by = quantile]
  }
  # Deviating each region by 1mwh at every quantile: each row is repeated for
  # 'baseline' and for every region, and only the copy whose deviation
  # matches the row's own region is shifted
  load_dt_adj =
    load_dt_raw[,
      as.list(data.table(deviation = c('baseline',unique(nerc_load_dt$nerc_adj)))),
      by = load_dt_raw
    ][,excess_load := fcase(
      deviation == nerc_adj, excess_load + 1,
      deviation == 'baseline' | deviation != nerc_adj, excess_load,
      default = NA
    )]
  # Adding squared term
  load_dt =
    rbind(
      load_dt_adj,
      load_dt_adj[,.(
        nerc_adj,
        quantile,
        tot_eload,
        square = TRUE,
        excess_load = excess_load^2,
        deviation
      )]
    )
  return(load_dt)
}


# Calculating marginal damages
  # md_dt = 
  #   dcast(
  #     plant_gen_dt, 
  #     formula = plant_id_eia + quantile ~ deviation, 
  #     value.var = 'damages'
  #   )[,.(
  #     plant_id_eia, quantile, 
  #     md_cal = cal - baseline,
  #     md_mro = mro - baseline,
  #     md_npcc = npcc - baseline,
  #     md_rfc = rfc - baseline,
  #     md_serc = serc - baseline,
  #     md_tre = tre - baseline,
  #     md_wecc = wecc - baseline
  #   )]
  # # Calculating marginal production
  # mp_dt = 
  #   dcast(
  #     plant_gen_dt, 
  #     formula = plant_id_eia + quantile ~ deviation, 
  #     value.var = 'fit_net_generation_mwh'
  #   )[,.(
  #     plant_id_eia, quantile, 
  #     mp_cal = cal - baseline,
  #     mp_mro = mro - baseline,
  #     mp_npcc = npcc - baseline,
  #     mp_rfc = rfc - baseline,
  #     mp_serc = serc - baseline,
  #     mp_tre = tre - baseline,
  #     mp_wecc = wecc - baseline
  #   )]
  # # adding back together 
  # out_dt = 
  #   merge(
  #     plant_gen_dt[deviation == 'baseline',-'deviation'],
  #     md_dt, 
  #     by = c('plant_id_eia','quantile')
  #   ) |>
  #   merge(
  #     mp_dt, 
  #     by = c('plant_id_eia','quantile')
  #   )

# Simulate production for one plant over every row of a load table
#
# Arguments:
#   plant_id_eia_in: plant id, matched against plant_dt$plant_id_eia
#   load_dt: load table (e.g. the output of eload_proportinal())
#   n_eps_draws: number of epsilon draws generated for the plant
#
# Returns the result of predict_production(), which is defined outside this
# section of the file.
# Reads the global tables `plant_dt` and `model_dt_long`.
sim_eload_quants = function(plant_id_eia_in, load_dt, n_eps_draws = 10000){
  # Drawing epsilons for this plant only.
  # generate_epsilons() takes just (plant_dt, n_each); the former
  # `col_name = 'quantile'` argument is not part of its signature and made
  # the call fail with an "unused argument" error.
  epsilon_dt =
    generate_epsilons(
      plant_dt = plant_dt[plant_id_eia == plant_id_eia_in],
      n_each = n_eps_draws
    )
  # Calculating production
  predict_production(
    plant_id_eia_in = plant_id_eia_in,
    load_dt = load_dt,
    plant_dt = plant_dt,
    model_dt_long = model_dt_long,
    epsilon_dt = epsilon_dt
  )
}

# Function to create table of epsilons for each plant
#
# Arguments:
#   plant_dt: data.table with columns plant_id_eia and log_scale
#   n_each: number of draws per plant
#
# Returns a data.table keyed by plant_id_eia with columns plant_id_eia,
# epsilon_id (1..n_each) and epsilon, where epsilon is drawn from a normal
# distribution with mean 0 and standard deviation exp(log_scale) of the plant.
# Plants are processed in the order they appear in `plant_dt`, so the draws
# are reproducible under a fixed seed.
generate_epsilons = function(plant_dt, n_each = 1000){
  # No plants: return an empty table with the expected columns and key
  if(length(plant_dt$plant_id_eia) == 0L){
    return(data.table(
      plant_id_eia = character(),
      epsilon_id = integer(),
      epsilon = numeric(),
      key = 'plant_id_eia'
    ))
  }
  # One table of draws per plant, stacked with data.table's own binder so the
  # result is a genuine data.table (no dplyr round trip)
  epsilon_dt =
    lapply(
      plant_dt$plant_id_eia,
      \(id){
        data.table(
          plant_id_eia = as.character(id),
          # seq_len() stays empty when n_each is 0, unlike 1:n_each
          epsilon_id = seq_len(n_each),
          epsilon = rnorm(
            n_each, mean = 0,
            sd = exp(plant_dt[plant_id_eia == as.character(id)]$log_scale)
          )
        )
      }
    ) |>
    rbindlist() |>
    setkey(plant_id_eia)
  return(epsilon_dt)
}


# Getting load for each quantile, under each of the splitting schemes
quant_load_dt = eload_proportinal(nerc_load_dt = nerc_load_dt, type = 'quants')
prop_load_dt = eload_proportinal(nerc_load_dt = nerc_load_dt, type = 'ic')
all_prop_load_dt = eload_proportinal(nerc_load_dt = nerc_load_dt, type = 'all')
# Same as 'all' but with the load range fixed by hand
man_prop_load_dt = eload_proportinal(
  nerc_load_dt = nerc_load_dt,
  manual_eload_limits = c(0,5e5),
  type = 'all'
)
# Observed load in the same layout: the timestamp takes the place of the
# quantile, tot_eload is left missing, and the squared rows are stacked below
# the linear ones
actual_load_dt =
  rbindlist(
    list(
      nerc_load_dt[,.(
        nerc_adj,
        quantile = datetime_utc,
        tot_eload = NA,
        square = FALSE,
        excess_load
      )],
      nerc_load_dt[,.(
        nerc_adj,
        quantile = datetime_utc,
        tot_eload = NA,
        square = TRUE,
        excess_load = excess_load_sq
      )]
    ),
    use.names = TRUE
  )

# Running it! Every simulation starts from the same seed (218), loops over
# the plants in plant_dt and is written to its own fst file.

# Region-specific quantiles
set.seed(218)
quant_plant_dt =
  plant_dt$plant_id_eia |>
  map_dfr(sim_eload_quants, load_dt = quant_load_dt)
write.fst(
  quant_plant_dt,
  here('Data/electricity-generation/output/quant-plant-dt.fst')
)

# Proportional split within interconnection
set.seed(218)
prop_plant_dt =
  plant_dt$plant_id_eia |>
  map_dfr(sim_eload_quants, load_dt = prop_load_dt)
write.fst(
  prop_plant_dt,
  here('Data/electricity-generation/output/prop-plant-dt.fst')
)

# Proportional split across all regions
set.seed(218)
all_plant_dt =
  plant_dt$plant_id_eia |>
  map_dfr(sim_eload_quants, load_dt = all_prop_load_dt)
write.fst(
  all_plant_dt,
  here('Data/electricity-generation/output/all-plant-dt.fst')
)

# Proportional split across all regions, manual load limits
set.seed(218)
man_plant_dt =
  plant_dt$plant_id_eia |>
  map_dfr(sim_eload_quants, load_dt = man_prop_load_dt)
write.fst(
  man_plant_dt,
  here('Data/electricity-generation/output/man-plant-dt.fst')
)

# Observed load
set.seed(218)
actual_plant_dt =
  plant_dt$plant_id_eia |>
  map_dfr(sim_eload_quants, load_dt = actual_load_dt)
write.fst(
  actual_plant_dt,
  here('Data/electricity-generation/output/actual-plant-dt.fst')
)


# Choose the type of simulation: change the file name to read another run
sim_plant_dt_old = read.fst(
  path = here('Data/electricity-generation/output/quant-plant-dt.fst'),
  as.data.table = TRUE
)

# Now we get to plot things
# Plant characteristics: for each of the two info files keep the
# earliest-year record per plant, stack them (filling columns missing from
# either file) and keep the fuel category and NERC region
plant_info_dt =
  c(
    "Data/electricity-generation/plant-info-dt-oge.fst",
    "Data/electricity-generation/shaped-info-dt-oge.fst"
  ) |>
  lapply(\(rel_path){
    read.fst(path = here(rel_path), as.data.table = TRUE)[,
      .SD[which.min(year)],
      by = .(plant_id_eia = as.character(plant_id_eia))
    ]
  }) |>
  rbindlist(use.names = TRUE, fill = TRUE) %>%
  .[,.(plant_id_eia, fuel_category, nerc_adj)]
# Merging the simulation output with the plant characteristics and adding the
# interconnection.
# The simulation read from disk is stored as `sim_plant_dt_old`; the merge
# previously referenced `sim_plant_dt` itself, which does not exist yet at
# this point of the script.
sim_plant_dt =
  merge(
    # Drop any stale copies of the columns that plant_info_dt supplies
    sim_plant_dt_old[,-c('fuel_category','nerc_adj')],
    plant_info_dt,
    by = 'plant_id_eia'
  )[,ic := factor(
    fcase(
      nerc_adj == 'TRE', 'Texas',
      nerc_adj %in% c("CAL",'WECC'), 'West',
      default = 'East'
    ),
    levels = c('West','Texas','East')
  )]

# Share of the total excess load that falls in each interconnection, taken
# from the baseline, non-squared rows of the manual-limits table at
# quantile == 1
ic_props =
  man_prop_load_dt[
    square == FALSE & deviation == 'baseline' & quantile == 1,
    .(
      nerc_adj = toupper(nerc_adj),
      ic = factor(
        fcase(
          toupper(nerc_adj) == 'TRE', 'Texas',
          toupper(nerc_adj) %in% c("CAL",'WECC'), 'West',
          default = 'East'
        ),
        levels = c('West','Texas','East')
      ),
      quantile,
      excess_load,
      tot_eload
    )
  ][,
    .(prop_ic = sum(excess_load)/median(tot_eload)),
    by = .(ic)
  ]
# Summarizing by region: totals of generation, emissions and damages, plus
# the matching per-MWh rates, for every region x quantile x total load
region_sim_dt =
  sim_plant_dt[
    ,
    .(
      # Totals
      tot_gen_mwh = sum(fit_net_generation_mwh),
      tot_co2e_tons = sum(fit_emissions_tons_co2e),
      tot_nox_tons = sum(fit_emissions_tons_nox),
      tot_so2_tons = sum(fit_emissions_tons_so2),
      tot_pm25_tons = sum(fit_emissions_tons_pm25),
      tot_damages = sum(damages),
      # Rates: total divided by total generation
      co2e_tons_per_mwh = sum(fit_emissions_tons_co2e)/sum(fit_net_generation_mwh),
      nox_tons_per_mwh  = sum(fit_emissions_tons_nox)/sum(fit_net_generation_mwh),
      so2_tons_per_mwh  = sum(fit_emissions_tons_so2)/sum(fit_net_generation_mwh),
      pm25_tons_per_mwh = sum(fit_emissions_tons_pm25)/sum(fit_net_generation_mwh),
      damages_per_mwh = sum(damages)/sum(fit_net_generation_mwh)
    ),
    keyby = .(nerc_adj, ic, quantile, tot_eload)
  ]
# Same summary by interconnection. Note that the grouping column holding the
# interconnection label is named `nerc_adj` in the output.
ic_sim_dt =
  sim_plant_dt[
    ,
    .(
      # Totals
      tot_gen_mwh = sum(fit_net_generation_mwh),
      tot_co2e_tons = sum(fit_emissions_tons_co2e),
      tot_nox_tons = sum(fit_emissions_tons_nox),
      tot_so2_tons = sum(fit_emissions_tons_so2),
      tot_pm25_tons = sum(fit_emissions_tons_pm25),
      tot_damages = sum(damages),
      # Rates: total divided by total generation
      co2e_tons_per_mwh = sum(fit_emissions_tons_co2e)/sum(fit_net_generation_mwh),
      nox_tons_per_mwh  = sum(fit_emissions_tons_nox)/sum(fit_net_generation_mwh),
      so2_tons_per_mwh  = sum(fit_emissions_tons_so2)/sum(fit_net_generation_mwh),
      pm25_tons_per_mwh = sum(fit_emissions_tons_pm25)/sum(fit_net_generation_mwh),
      damages_per_mwh = sum(damages)/sum(fit_net_generation_mwh)
    ),
    keyby = .(
      nerc_adj = factor(
        fcase(
          nerc_adj == 'TRE', 'Texas',
          nerc_adj %in% c("CAL",'WECC'), 'West',
          default = 'East'
        ),
        levels = c('West','Texas','East')
      ),
      quantile,
      tot_eload
    )
  ]
# Sums of every column whose name matches 'md|mp', by quantile and total
# load, reshaped so that each row is a region of the deviation (taken from
# the column name) with one `md` and one `mp` column, then matched to the
# interconnection-level excess load
md_sim_dt =
  sim_plant_dt[,
    lapply(.SD, sum),
    keyby = .(quantile, tot_eload),
    .SDcols = str_subset(colnames(sim_plant_dt),'md|mp')
  ] |>
  melt(id.vars = c('quantile','tot_eload')) %>%
  .[, {
    # Region code = column name without its 'md_'/'mp_' part, upper-cased
    region = toupper(str_remove(variable, 'm(d|p)_'))
    list(
      nerc_adj = region,
      variable = str_extract(variable, 'm(d|p)'),
      ic = factor(
        fcase(
          region == 'TRE', 'Texas',
          region %in% c("CAL",'WECC'), 'West',
          default = 'East'
        ),
        levels = c('West','Texas','East')
      ),
      quantile = quantile,
      tot_eload = tot_eload,
      value = value
    )
  }] |>
  dcast(
    formula = nerc_adj + ic + quantile + tot_eload ~ variable,
    value.var = 'value'
  ) |>
  # IF USING QUANT_LOAD_DT: baseline excess load summed by interconnection
  merge(
    quant_load_dt[
      square == FALSE & deviation == 'baseline',
      .(tot_eload_ic = sum(excess_load)),
      by = .(
        quantile,
        ic = factor(
          fcase(
            toupper(nerc_adj) == 'TRE', 'Texas',
            toupper(nerc_adj) %in% c("CAL",'WECC'), 'West',
            default = 'East'
          ),
          levels = c('West','Texas','East')
        )
      )
    ],
    by = c('ic','quantile')
  ) %>%
  # Ordered region factor used for colours and line types in the plots
  .[, nerc_adj_f := factor(
    nerc_adj,
    levels = c('CAL','WECC','TRE','MRO','NPCC','RFC','SERC')
  )]
  # IF USING MAN_PROP_LOAD_DT
  #merge(
  #  ic_props, 
  #  by = 'ic'
  #) %>% 
  #.[,':='(
  #  tot_eload_ic = tot_eload*prop_ic,
  #  prop_ic = NULL,
  #  nerc_adj_f = factor(
  #    nerc_adj, 
  #    levels = c('CAL','WECC','TRE','MRO','NPCC','RFC','SERC')
  #  )
  #)]



# Plots for the paper ---------------------------------------------------------
# md/mp against interconnection excess load, one line per region, one panel
# per interconnection
region_md_demand_p =
  ggplot(
    data = md_sim_dt[quantile %between% c(0.00,1.00)],
    mapping = aes(
      x = tot_eload_ic/1e3,
      y = md/mp,
      color = nerc_adj_f,
      linetype = nerc_adj_f
    )
  ) +
  geom_line(size = 1.25) +
  facet_grid(cols = vars(ic), scales = 'free_x') +
  scale_color_brewer(name = 'Region', palette = 'Dark2') +
  scale_linetype_manual(
    name = 'Region',
    values = c('solid','twodash','dashed','longdash','solid','dashed','solid')
  ) +
  labs(
    x = 'Excess Load (GWh)',
    y = 'Marginal Damage ($/MWh)'
  ) +
  theme(
    legend.position = 'bottom',
    legend.key.width = unit(3, 'cm')
  )
  region_md_demand_p
  ggsave(
    filename = here('figures/electricity-generation/eload-sim-region-md-demand.jpeg'),
    plot = region_md_demand_p,
    width = 12,
    height = 5,
    bg = 'white'
  )
# mp against interconnection excess load, one line per region, one panel per
# interconnection
region_mp_demand_p =
  ggplot(
    data = md_sim_dt,
    mapping = aes(
      x = tot_eload_ic/1e3,
      y = mp,
      color = nerc_adj_f,
      linetype = nerc_adj_f
    )
  ) +
  geom_line(size = 1.25) +
  facet_grid(cols = vars(ic), scales = 'free_x') +
  scale_color_brewer(name = 'Region', palette = 'Dark2') +
  scale_linetype_manual(
    name = 'Region',
    values = c('solid','twodash','dashed','longdash','solid','dashed','solid')
  ) +
  labs(
    x = 'Excess Load (GWh)',
    y = 'Marginal Production'
  ) +
  theme(
    legend.position = 'bottom',
    legend.key.width = unit(3, 'cm')
  )
  region_mp_demand_p
  ggsave(
    filename = here('figures/electricity-generation/eload-sim-region-mp-demand.jpeg'),
    plot = region_mp_demand_p,
    width = 12,
    height = 5,
    bg = 'white'
  )

  # Density of load: histogram of the excess load summed over the regions of
  # each interconnection at every timestamp (divided by 1e3)
  eload_sim_ic_hist_p =
    ggplot(
      data = nerc_load_dt[,
        .(tot_eload_ic = sum(excess_load)/1e3),
        keyby = .(
          datetime_utc,
          ic = factor(
            fcase(
              nerc_adj == 'tre','Texas',
              nerc_adj %in% c('cal', 'wecc'), 'West',
              default = 'East'
            ),
            levels = c('West','Texas','East')
          )
        )
      ],
      mapping = aes(x = tot_eload_ic)
    ) +
    geom_histogram(bins = 70) +
    facet_wrap(~ic, scales = 'free_x') +
    labs(
      x = 'Excess Load (GWh)',
      y = 'Number of Hours'
    )
  ggsave(
    filename = here('figures/electricity-generation/eload-sim-ic-hist-p.jpeg'),
    plot = eload_sim_ic_hist_p,
    width = 12,
    height = 5,
    bg = 'white'
  )

  # Make 3 plots (one per interconnection) so each legend sits underneath its
  # own panel. This is the West panel; it is the only one with a y-axis title.
  md_pW =
    ggplot(
      data = md_sim_dt[quantile %between% c(0.0,1.00) & ic == 'West'],
      mapping = aes(
        x = tot_eload_ic/1e3,
        y = md/mp,
        color = nerc_adj_f,
        linetype = nerc_adj_f
      )
    ) +
    geom_line(size = 1.25) +
    scale_color_brewer(name = 'Region', palette = 'Dark2') +
    scale_linetype_manual(
      name = 'Region',
      values = c('solid','twodash','dashed','longdash')
    ) +
    # Same y range in all three panels
    ylim(29,112) +
    labs(
      title = 'West',
      x = '',
      y = 'Marginal Damage ($/MWh)'
    ) +
    guides(
      color = guide_legend(nrow = 2, byrow = TRUE),
      linetype = guide_legend(nrow = 2, byrow = TRUE)
    ) +
    theme(
      plot.title = element_text(
        hjust = 0.5,
        color = 'grey10',
        size = rel(1.0)
      ),
      legend.position = 'bottom',
      legend.key.width = unit(2.25, 'cm')
    )
    # Texas panel: a single region, drawn with the third Dark2 colour; the
    # y-axis title and tick labels are blanked
    md_pT =
      ggplot(
        data = md_sim_dt[quantile %between% c(0.0,1.00) & ic == 'Texas'],
        mapping = aes(
          x = tot_eload_ic/1e3,
          y = md/mp,
          color = nerc_adj_f,
          linetype = nerc_adj_f
        )
      ) +
      geom_line(size = 1.25) +
      scale_color_manual(
        name = '',
        values = RColorBrewer::brewer.pal(7, "Dark2")[3]
      ) +
      scale_linetype_manual(
        name = '',
        values = c('solid')
      ) +
      ylim(29,112) +
      labs(
        title = 'Texas',
        x = 'Excess Load (GWh)'
      ) +
      guides(
        color = guide_legend(nrow = 2, byrow = TRUE),
        linetype = guide_legend(nrow = 2, byrow = TRUE)
      ) +
      theme(
        plot.title = element_text(
          hjust = 0.5,
          color = 'grey10',
          size = rel(1.0)
        ),
        axis.title.y = element_blank(),
        axis.text.y = element_blank(),
        legend.position = 'bottom',
        legend.key.width = unit(2.25, 'cm')
      )
    # East panel: Dark2 colours 4 to 7; the y-axis title and tick labels are
    # blanked
    md_pE =
      ggplot(
        data = md_sim_dt[quantile %between% c(0.0,1.00) & ic == 'East'],
        mapping = aes(
          x = tot_eload_ic/1e3,
          y = md/mp,
          color = nerc_adj_f,
          linetype = nerc_adj_f
        )
      ) +
      geom_line(size = 1.25) +
      scale_color_manual(
        name = '',
        values = RColorBrewer::brewer.pal(7, "Dark2")[4:7]
      ) +
      scale_linetype_manual(
        name = '',
        values = c('solid','twodash','dashed','longdash')
      ) +
      ylim(29,112) +
      labs(
        title = 'East',
        x = ''
      ) +
      guides(
        color = guide_legend(nrow = 2, byrow = TRUE),
        linetype = guide_legend(nrow = 2, byrow = TRUE)
      ) +
      theme(
        plot.title = element_text(
          hjust = 0.5,
          color = 'grey10',
          size = rel(1.0)
        ),
        axis.title.y = element_blank(),
        axis.text.y = element_blank(),
        legend.position = 'bottom',
        legend.key.width = unit(2.25, 'cm')
      )
  # Adding all three plots together (side by side via patchwork's `+`)
  ggsave(
    filename = here('figures/electricity-generation/eload-sim-region-md-mwh-demand.jpeg'),
    plot = md_pW + md_pT + md_pE,
    width = 12,
    height = 5,
    bg = 'white'
  )
  

# How much is consumed within region vs exported
# Same reshaping as for the region-level table, but the sums are also split
# by the region the plants are in (nerc_adj_prod); the region taken from the
# column name is the one whose load was deviated (nerc_adj_demand)
md_region_sim_dt =
  sim_plant_dt[,
    lapply(.SD, sum),
    keyby = .(quantile, tot_eload, nerc_adj),
    .SDcols = str_subset(colnames(sim_plant_dt),'md|mp')
  ] |>
  melt(id.vars = c('quantile','tot_eload','nerc_adj')) %>%
  .[, {
    # Region code = column name without its 'md_'/'mp_' part, upper-cased
    demand_region = toupper(str_remove(variable, 'm(d|p)_'))
    list(
      nerc_adj_prod = nerc_adj,
      nerc_adj_demand = demand_region,
      variable = str_extract(variable, 'm(d|p)'),
      ic = factor(
        fcase(
          demand_region == 'TRE', 'Texas',
          demand_region %in% c("CAL",'WECC'), 'West',
          default = 'East'
        ),
        levels = c('West','Texas','East')
      ),
      quantile = quantile,
      tot_eload = tot_eload,
      value = value
    )
  }] |>
  dcast(
    formula = nerc_adj_prod + nerc_adj_demand + ic + quantile + tot_eload ~ variable,
    value.var = 'value'
  ) |>
  # IF USING QUANT_LOAD_DT: baseline excess load summed by interconnection
  merge(
    quant_load_dt[
      square == FALSE & deviation == 'baseline',
      .(tot_eload_ic = sum(excess_load)),
      by = .(
        quantile,
        ic = factor(
          fcase(
            toupper(nerc_adj) == 'TRE', 'Texas',
            toupper(nerc_adj) %in% c("CAL",'WECC'), 'West',
            default = 'East'
          ),
          levels = c('West','Texas','East')
        )
      )
    ],
    by = c('ic','quantile')
  ) %>%
  # Both region columns become factors with the same level order
  .[, `:=`(
    nerc_adj_prod = factor(
      nerc_adj_prod,
      levels = c('CAL','WECC','TRE','MRO','NPCC','RFC','SERC')
    ),
    nerc_adj_demand = factor(
      nerc_adj_demand,
      levels = c('CAL','WECC','TRE','MRO','NPCC','RFC','SERC')
    )
  )] %>%
  # Share of each demand region's mp that comes from each producing region
  .[, share_of_mp := mp/sum(mp), by = .(quantile, nerc_adj_demand)]


# PLOT FOR PAPER: DEMAND FILLED BUT ONLY IN OWN REGION
# West panel, restricted to rows where the producing region is the region
# whose load was deviated
md_pW_own =
    ggplot(
      data = md_region_sim_dt[nerc_adj_prod == nerc_adj_demand & ic == 'West'],
      mapping = aes(
        x = tot_eload_ic/1e3,
        y = md/mp,
        color = nerc_adj_demand,
        linetype = nerc_adj_demand
      )
    ) +
    geom_line(size = 1.25) +
    scale_color_brewer(name = 'Region', palette = 'Dark2') +
    scale_linetype_manual(
      name = 'Region',
      values = c('solid','twodash','dashed','longdash')
    ) +
    # Same y range in all three panels
    ylim(25,125) +
    labs(
      title = 'West',
      x = '',
      y = 'Marginal Damage ($/MWh)'
    ) +
    guides(
      color = guide_legend(nrow = 2, byrow = TRUE),
      linetype = guide_legend(nrow = 2, byrow = TRUE)
    ) +
    theme(
      plot.title = element_text(
        hjust = 0.5,
        color = 'grey10',
        size = rel(1.0)
      ),
      legend.position = 'bottom',
      legend.key.width = unit(2.25, 'cm')
    )
    # Texas panel (own-region rows only), third Dark2 colour; the y-axis
    # title and tick labels are blanked
    md_pT_own =
      ggplot(
        data = md_region_sim_dt[nerc_adj_prod == nerc_adj_demand & ic == 'Texas'],
        mapping = aes(
          x = tot_eload_ic/1e3,
          y = md/mp,
          color = nerc_adj_demand,
          linetype = nerc_adj_demand
        )
      ) +
      geom_line(size = 1.25) +
      scale_color_manual(
        name = '',
        values = RColorBrewer::brewer.pal(7, "Dark2")[3]
      ) +
      scale_linetype_manual(
        name = '',
        values = c('solid','twodash','dashed','longdash')
      ) +
      ylim(25,125) +
      labs(
        title = 'Texas',
        x = 'Excess Load (GWh)'
      ) +
      guides(
        color = guide_legend(nrow = 2, byrow = TRUE),
        linetype = guide_legend(nrow = 2, byrow = TRUE)
      ) +
      theme(
        plot.title = element_text(
          hjust = 0.5,
          color = 'grey10',
          size = rel(1.0)
        ),
        axis.title.y = element_blank(),
        axis.text.y = element_blank(),
        legend.position = 'bottom',
        legend.key.width = unit(2.25, 'cm')
      )
    # East panel (own-region rows only), Dark2 colours 4 to 7; the y-axis
    # title and tick labels are blanked
    md_pE_own =
      ggplot(
        data = md_region_sim_dt[nerc_adj_prod == nerc_adj_demand & ic == 'East'],
        mapping = aes(
          x = tot_eload_ic/1e3,
          y = md/mp,
          color = nerc_adj_demand,
          linetype = nerc_adj_demand
        )
      ) +
      geom_line(size = 1.25) +
      scale_color_manual(
        name = '',
        values = RColorBrewer::brewer.pal(7, "Dark2")[4:7]
      ) +
      scale_linetype_manual(
        name = '',
        values = c('solid','twodash','dashed','longdash')
      ) +
      ylim(25,125) +
      labs(
        title = 'East',
        x = ''
      ) +
      guides(
        color = guide_legend(nrow = 2, byrow = TRUE),
        linetype = guide_legend(nrow = 2, byrow = TRUE)
      ) +
      theme(
        plot.title = element_text(
          hjust = 0.5,
          color = 'grey10',
          size = rel(1.0)
        ),
        axis.title.y = element_blank(),
        axis.text.y = element_blank(),
        legend.position = 'bottom',
        legend.key.width = unit(2.25, 'cm')
      )
  # Adding all three plots together (side by side via patchwork's `+`)
  ggsave(
    filename = here('figures/electricity-generation/eload-sim-region-md-mwh-own-demand.jpeg'),
    plot = md_pW_own + md_pT_own + md_pE_own,
    width = 12,
    height = 5,
    bg = 'white'
  )
  



  # Share of marginal product from own region (displayed only, not saved)
  ggplot(
    data = md_region_sim_dt[nerc_adj_prod == nerc_adj_demand],
    mapping = aes(
      x = tot_eload_ic/1e3,
      y = share_of_mp,
      color = nerc_adj_demand,
      linetype = nerc_adj_demand
    )
  ) +
    geom_line(size = 1.25) +
    facet_grid(cols = vars(ic), scales = 'free_x') +
    scale_color_brewer(name = 'Region', palette = 'Dark2') +
    scale_linetype_manual(
      name = 'Region',
      values = c('solid','twodash','dashed','longdash','solid','dashed','solid')
    ) +
    labs(
      x = 'Excess Load (GWh)',
      y = 'Share of Marg. Prod. from own region'
    ) +
    theme(
      legend.position = 'bottom',
      legend.key.width = unit(3, 'cm')
    )

  # Checking how many plants are at their nameplate capacity
  # Per quantile x interconnection x region: share of plants whose fitted
  # generation exceeds 99% of namepcap, share with fitted generation <= 0,
  # and the summed gap between namepcap and fitted generation
  cap_dt =
    merge(
      x = sim_plant_dt,
      y = plant_dt[,.(plant_id_eia, namepcap)],
      by = 'plant_id_eia'
    )[,.(
      quantile,
      tot_eload,
      nerc_adj,
      ic,
      at_capacity = fit_net_generation_mwh/namepcap > 0.99,
      not_producing = fit_net_generation_mwh <= 0,
      remaining_capacity = namepcap - fit_net_generation_mwh
    )][,
      .(
        at_capacity = mean(at_capacity),
        not_producing = mean(not_producing),
        remaining_capacity = sum(remaining_capacity)
      ),
      keyby = .(quantile, ic, nerc_adj)
    ] |>
    # IF USING QUANT_LOAD_DT: baseline excess load summed by interconnection
    merge(
      quant_load_dt[
        square == FALSE & deviation == 'baseline',
        .(tot_eload_ic = sum(excess_load)),
        by = .(
          quantile,
          ic = factor(
            fcase(
              toupper(nerc_adj) == 'TRE', 'Texas',
              toupper(nerc_adj) %in% c("CAL",'WECC'), 'West',
              default = 'East'
            ),
            levels = c('West','Texas','East')
          )
        )
      ],
      by = c('ic','quantile')
    ) %>%
    # Ordered region factor used for colours and line types
    .[, nerc_adj_f := factor(
      nerc_adj,
      levels = c('CAL','WECC','TRE','MRO','NPCC','RFC','SERC')
    )]

  # Remaining capacity (namepcap minus fitted generation, summed by region
  # and divided by 1e3) against interconnection excess load. Displayed only,
  # not saved.
  # To plot the share of plants near capacity instead, map y to `at_capacity`
  # and adjust the y label and limits accordingly.
  ggplot(
    cap_dt,
    aes(
      x = tot_eload_ic/1e3,
      y = remaining_capacity/1e3,
      color = nerc_adj_f,
      linetype = nerc_adj_f
    )
  ) +
  geom_line(size = 1.25) +
  scale_color_brewer(name = 'Region',palette = 'Dark2') +
  scale_linetype_manual(
    name = 'Region',
    values = c('solid','twodash','dashed','longdash','solid','dashed','solid')
  )+
  labs(
    x = 'Excess Load (GWh)',
    # The label now describes the plotted variable; it previously read
    # 'Share within 1pp capacity', which belongs to `at_capacity`
    y = 'Remaining Capacity (thousands)'
  ) +
  ylim(0,175) +
  facet_grid(cols = vars(ic), scales = 'free_x')+
  theme(
    legend.position = 'bottom',
    legend.key.width = unit(3, 'cm')
  )


  # Fuel mix by interconnection: share of fitted generation by fuel category
  # at each quantile, plotted against total load scaled by the
  # interconnection's share (prop_ic from ic_props)
  ic_fuel_mix_p =
    ggplot(
      data =
        sim_plant_dt[,
          .(tot_gen_mwh = sum(fit_net_generation_mwh)),
          keyby = .(
            ic,
            # Anything that is not one of the four named fuels is 'Other'
            fuel_category = factor(
              fcase(
                fuel_category == 'natural_gas', 'Natural Gas',
                fuel_category == 'coal', 'Coal',
                fuel_category == 'hydro', 'Hydro',
                fuel_category == 'nuclear', 'Nuclear',
                default = 'Other'
              ),
              levels = c('Natural Gas','Coal', 'Nuclear','Hydro','Other')
            ),
            quantile,
            tot_eload
          )
        ][,
          pct_gen := tot_gen_mwh/sum(tot_gen_mwh),
          by = .(ic, quantile)
        ] |>
        merge(ic_props, by = 'ic') %>%
        .[, tot_eload_ic := tot_eload*prop_ic] %>%
        .[, prop_ic := NULL],
      mapping = aes(
        x = tot_eload_ic/1e3,
        y = pct_gen,
        color = fuel_category,
        linetype = fuel_category
      )
    ) +
    geom_line(size = 1.05) +
    facet_wrap(~ic, scales = 'free_x') +
    scale_color_brewer(name = 'Fuel Category', palette = 'Dark2') +
    scale_linetype_manual(
      name = 'Fuel Category',
      values = c('solid','longdash','twodash','dashed','dotted')
    ) +
    scale_y_continuous(labels = scales::label_percent()) +
    labs(
      x = 'Excess Load (GWh)',
      y = 'Percent of Electricity Production'
    ) +
    guides(
      linetype = guide_legend(keywidth = 4, keyheight = 1),
      colour = guide_legend(keywidth = 4, keyheight = 1)
    ) +
    theme(legend.position = 'bottom')
    # Saved twice: jpeg (white background) and pdf (cairo device)
    ggsave(
      filename = here('figures/electricity-generation/eload-sim-ic-fuel-mix.jpeg'),
      plot = ic_fuel_mix_p,
      width = 11,
      height = 5,
      bg = 'white'
    )
    ggsave(
      filename = here('figures/electricity-generation/eload-sim-ic-fuel-mix.pdf'),
      plot = ic_fuel_mix_p,
      width = 11,
      height = 5,
      device = cairo_pdf
    )
  


# Other plots -----------------------------------------------------------------
# Region summary matched to the baseline interconnection-level excess load.
# NOTE(review): the result is not assigned, so at top level it is only
# printed -- confirm whether it was meant to be stored.
merge(
  x = region_sim_dt,
  y = quant_load_dt[
    square == FALSE & deviation == 'baseline',
    .(tot_eload_ic = sum(excess_load)),
    by = .(
      quantile,
      ic = factor(
        fcase(
          toupper(nerc_adj) == 'TRE', 'Texas',
          toupper(nerc_adj) %in% c("CAL",'WECC'), 'West',
          default = 'East'
        ),
        levels = c('West','Texas','East')
      )
    )
  ],
  by = c('ic','quantile')
)


# First how does total production change? 
  # Generation relative to total excess load, one line per region.
  # NOTE(review): y is the ratio tot_gen_mwh / tot_eload, while the axis label
  # reads 'Electricity Production (GWh)' -- confirm which one is intended.
  region_gen_p =
    region_sim_dt |>
    ggplot(aes(
      x = tot_eload / 1e3,
      y = tot_gen_mwh / tot_eload,
      color = nerc_adj
    )) +
    geom_line(size = 1.25) +
    scale_color_brewer(palette = 'Dark2', name = 'Region') +
    #scale_x_continuous(labels = scales::label_percent())+
    labs(
      x = 'Excess Load (GWh)',
      y = 'Electricity Production (GWh)'
    )
  # Show the figure when run interactively
  region_gen_p
  # Save a JPEG copy with a white background
  ggsave(
    filename = here('figures/electricity-generation/eload-sim-region-gen.jpeg'),
    plot = region_gen_p,
    width = 8,
    height = 5,
    bg = 'white'
  )
  # How do damages change?
  # Total damages (divided by 1e6) against the percentile of excess load,
  # one line per region.
  region_damages_p =
    region_sim_dt |>
    ggplot(aes(
      x = quantile,
      y = tot_damages / 1e6,
      color = nerc_adj
    )) +
    geom_line(size = 1.25) +
    scale_color_brewer(palette = 'Dark2', name = 'Region') +
    scale_x_continuous(labels = scales::label_percent()) +
    labs(
      x = 'Percentile of Excess Load',
      y = 'Damages ($1M)'
    )
  # Save a JPEG copy with a white background
  ggsave(
    filename = here('figures/electricity-generation/eload-sim-region-damages.jpeg'),
    plot = region_damages_p,
    width = 8,
    height = 5,
    bg = 'white'
  )
  # Pollutants
  # Reshape every column whose name ends in 'tons' to long format, tag each
  # row with a display label for its pollutant, and plot the values by
  # percentile of excess load with one panel per pollutant.
  region_pollutants_p =
    melt(
      region_sim_dt,
      measure.vars = patterns('tons$'),
      id.vars = c('nerc_adj', 'quantile')
    )[
      ,
      # Display label derived from the melted column name
      pollutant := fcase(
        str_detect(variable, 'co2e'), 'CO2e',
        str_detect(variable, 'so2'), 'SO2',
        str_detect(variable, 'nox'), 'NOx',
        str_detect(variable, 'pm25'), 'PM 2.5'
      )
    ] |>
    ggplot(aes(
      x = quantile,
      y = value,
      color = nerc_adj
    )) +
    geom_line(size = 1.05) +
    scale_color_brewer(palette = 'Dark2', name = 'Region') +
    scale_x_continuous(labels = scales::label_percent()) +
    scale_y_continuous(labels = scales::comma_format()) +
    labs(
      x = 'Percentile of Excess Load',
      y = 'Emissions (Tons)'
    ) +
    # Each pollutant gets its own panel with free axes
    facet_wrap(~pollutant, scales = 'free')
  # Save a JPEG copy with a white background
  ggsave(
    filename = here('figures/electricity-generation/eload-sim-region-emissions.jpeg'),
    plot = region_pollutants_p,
    width = 10,
    height = 7,
    bg = 'white'
  )
  # Marginal damages (take diff)
  # Row-to-row changes in damages are divided by row-to-row changes in total
  # excess load, regional excess load, and generation. Only `md_gen` is plotted.
  # NOTE(review): the differences run over the whole table in row order, not
  # within region. Presumably rows are sorted by region then quantile, so the
  # cross-region differences land on the quantile == 0 rows dropped below --
  # TODO confirm.
  region_md_p =
    region_sim_dt[, {
      # Shared first differences, padded with a leading NA
      d_damages <- c(NA, diff(tot_damages))
      d_gen <- c(NA, diff(tot_gen_mwh))
      list(
        quantile = quantile,
        tot_eload = tot_eload,
        nerc_adj = nerc_adj,
        ic = ic,
        tot_damages = tot_damages,
        md_tot_load = d_damages / c(NA, diff(tot_eload)),
        md_region_load = d_damages / c(NA, diff(excess_load)),
        md_gen = d_damages / d_gen,
        delta_damages = d_damages,
        delta_gen = d_gen
      )
    }][quantile != 0 & quantile != 1] |>
    ggplot(aes(
      x = tot_eload / 1e3,
      y = md_gen,
      color = nerc_adj,
      linetype = nerc_adj
    )) +
    # Raw points underneath, loess smooth on top
    geom_point(alpha = 0.15) +
    geom_smooth(method = 'loess', span = 0.4, se = FALSE) +
    scale_color_brewer(palette = 'Dark2', name = 'Region') +
    scale_linetype_manual(
      name = 'Region',
      values = c('solid','solid','twodash','dashed','longdash','solid','dashed')
    ) +
    labs(
      x = 'Excess Load (GWh)',
      y = 'Marginal Damages ($/MWh)'
    ) +
    # One column of panels per `ic` value, each with its own x range
    facet_grid(cols = vars(ic), scales = 'free_x') +
    theme(
      legend.key.width = unit(3, 'cm'),
      legend.position = 'bottom'
    )
  # Show the figure when run interactively
  region_md_p
  # Save a JPEG copy with a white background
  ggsave(
    filename = here('figures/electricity-generation/eload-sim-region-md.jpeg'),
    plot = region_md_p,
    width = 12,
    height = 5,
    bg = 'white'
  )
  # Marginal damages from `ic_sim_dt`: change in damages divided by change in
  # generation between consecutive rows.
  # NOTE(review): the differences run over the whole table in row order, not
  # within group. Presumably rows are sorted by group then quantile, so the
  # cross-group differences land on the quantile == 0 rows dropped below --
  # TODO confirm.
  ic_md_p =
    ic_sim_dt[, {
      # Shared first differences, padded with a leading NA
      d_damages <- c(NA, diff(tot_damages))
      d_gen <- c(NA, diff(tot_gen_mwh))
      list(
        quantile = quantile,
        nerc_adj = nerc_adj,
        tot_damages = tot_damages,
        md = d_damages / d_gen,
        delta_damages = d_damages,
        delta_gen = d_gen
      )
    }][quantile != 0 & quantile != 1] |>
    ggplot(aes(
      x = quantile,
      y = md,
      color = nerc_adj,
      linetype = nerc_adj
    )) +
    #geom_line(size = 1.25) +
    # Raw points underneath, loess smooth on top
    geom_point(alpha = 0.3) +
    geom_smooth(method = 'loess', se = FALSE) +
    scale_color_brewer(palette = 'Dark2', name = 'Interconnection') +
    scale_x_continuous(labels = scales::label_percent()) +
    # Restrict the y axis to the 20-70 range
    ylim(20, 70) +
    labs(
      x = 'Percentile of Excess Load',
      y = 'Marginal Damages ($/MWh)',
      linetype = 'Interconnection'
    ) +
    # Widen the legend keys so the line types are distinguishable
    guides(
      linetype = guide_legend(keywidth = 3, keyheight = 1),
      colour = guide_legend(keywidth = 3, keyheight = 1)
    ) +
    theme(legend.position = 'bottom')
  # JPEG copy with a white background
  ggsave(
    filename = here('figures/electricity-generation/eload-sim-ic-md.jpeg'),
    plot = ic_md_p,
    width = 6,
    height = 5,
    bg = 'white'
  )
  # PDF copy, rendered through the cairo_pdf device
  ggsave(
    filename = here('figures/electricity-generation/eload-sim-ic-md.pdf'),
    plot = ic_md_p,
    width = 6,
    height = 5,
    device = cairo_pdf
  )
  # Average damages
  # Damages per MWh of generation (tot_damages / tot_gen_mwh) against the
  # percentile of excess load, one line per region.
  region_avg_damages_p = 
    ggplot(region_sim_dt, aes(x =quantile, y = tot_damages/tot_gen_mwh, color = nerc_adj)) + 
    geom_line(size = 1.25) +
    scale_color_brewer(name = 'Region',palette = 'Dark2') + 
    scale_x_continuous(labels = scales::label_percent())+
    labs(
      x = 'Percentile of Excess Load',
      y = 'Damages per MWh'
    )
  # Saved under its own file name: this call previously reused
  # 'eload-sim-region-damages.jpeg', which overwrote the total-damages figure
  # (`region_damages_p`) written to that same path.
  ggsave(
    plot = region_avg_damages_p,
    filename = here('figures/electricity-generation/eload-sim-region-avg-damages.jpeg'),
    bg = 'white',
    width = 8, height = 5
  )
  # Fuel mix 
  # Share of fitted generation by fuel category, for each region and
  # percentile of excess load, with one panel per fuel category.
  region_fuel_mix_p =
    merge(
      sim_plant_dt,
      plant_info_dt, 
      by = 'plant_id_eia'
    ) |>
    # Inner join on region and quantile (fixed typos: `merge9` -> `merge`,
    # `quantil` -> `quantile`)
    merge(
      prop_load_dt[,.(nerc_adj, quantile, tot_eload)],
      by = c('nerc_adj','quantile')
    ) %>%
    # The `[` steps use the magrittr pipe with `.`, because the native pipe
    # does not accept a `[` call on its right-hand side.
    # Total fitted generation by region, fuel category and quantile
    .[,.(tot_gen_mwh = sum(fit_net_generation_mwh)), 
      keyby = .(
        nerc_adj, 
        fuel_category = 
          fcase(
            fuel_category == 'natural_gas', 'Natural Gas',
            fuel_category == 'coal', 'Coal',
            fuel_category == 'hydro', 'Hydro',
            fuel_category == 'geothermal', 'Geothermal',
            fuel_category == 'nuclear', 'Nuclear',
            default = 'Other'
          ) |>
          factor(
            levels = c('Natural Gas','Coal', 'Nuclear','Hydro','Geothermal','Other')
          ), 
        quantile
      )
    ] %>%
    # Convert totals to shares within each region-quantile cell
    .[,pct_gen := tot_gen_mwh/sum(tot_gen_mwh), 
      by = .(nerc_adj, quantile)
    ] |>
    ggplot(aes(x = quantile, y = pct_gen, color = nerc_adj)) + 
      geom_line(size = 1.05) +
      scale_color_brewer(name = 'Region',palette = 'Dark2') + 
      scale_x_continuous(labels = scales::label_percent())+
      scale_y_continuous(labels = scales::label_percent()) +
      labs(
        x = 'Percentile of Excess Load',
        y = 'Percent of Electricity Production'
      ) + 
      facet_wrap(~fuel_category)
    # Save a JPEG copy with a white background
    ggsave(
      plot = region_fuel_mix_p,
      filename = here('figures/electricity-generation/eload-sim-region-fuel-mix.jpeg'),
      bg = 'white',
      width = 10, height = 7
    )
  # by interconnection 
    
    
    
    
    